""" Standard incomplete-markets household model with endogenous labor supply. """

"""
    hh_init(; a_grid, we, r, eis, T)

Return a 1-tuple `(Va,)` holding an initial guess for the marginal value of assets.

The guess is `(1 + r) * (0.1 * coh)^(-1/eis)`, with cash on hand
`coh = (1 + r) * a + we + T`. This is the marginal value implied by consuming 10%
of cash on hand. The asset grid runs along the columns. `we .+ T` is broadcast
down the rows, so vector `we`/`T` give one row per entry.
"""
function hh_init(; a_grid, we, r, eis, T)
    R = 1 + r
    income = we .+ T
    cash_on_hand = R * a_grid' .+ income
    guess = R * (0.1 * cash_on_hand) .^ (-1 / eis)
    return (guess,)
end

"""
    _hh(; Va_p, a_grid, we, T, r, β, eis, ν, φ)

Single backward step via the endogenous gridpoint method (EGM).

Given next period's marginal value of assets `Va_p`, compute today's marginal value
`Va` and the policies for assets `a`, consumption `c` and labor `n`. The result is
returned as `(Va, a, c, n)`.

Shapes are inferred from the broadcasting below (TODO confirm against callers):
- `Va_p` is a matrix whose columns follow `a_grid`.
- `we` and `T` are vectors with one entry per row of `Va_p`.

`a_grid[1]` is treated as the borrowing limit.
"""
function _hh(; Va_p, a_grid, we, T, r, β, eis, ν, φ)
    # Euler equation: u'(c) = β * Va_p, mapped to (c, n) at each next-period asset a'.
    c_next, n_next = cn(β * Va_p, we, eis, ν, φ)

    # Budget c + a' = (1 + r) a + we * n + T, rearranged to give the beginning-of-period
    # resources (1 + r) a consistent with each grid point a' (columns of `a_grid'`).
    endog_resources = c_next .- we .* n_next .+ a_grid' .- T
    gross_return = (1 + r) .* a_grid
    # NOTE(review): `interpolate_y` is defined elsewhere; presumably it interpolates
    # the last argument from the `endog_resources` points onto `gross_return`.
    c = interpolate_y(endog_resources, gross_return, c_next)
    n = interpolate_y(endog_resources, gross_return, n_next)

    a = gross_return' .+ we .* n .+ T .- c

    # Enforce the borrowing limit. Where it binds, re-solve (c, n) so the budget
    # holds with a' set to the limit, seeding the solver with `Va_p`.
    a_min = a_grid[1]
    binding = findall(a .< a_min)
    a[binding] .= a_min
    if !isempty(binding)
        rows = getindex.(binding, 1)
        cols = getindex.(binding, 2)
        resources = gross_return[cols] .+ T[rows] .- a_min
        c[binding], n[binding] = solve_cn(we[rows], resources, eis, ν, φ, Va_p[binding])
    end

    # Envelope condition.
    Va = (1 + r) * c .^ (-1 / eis)
    return Va, a, c, n
end


# Household block: the backward step `_hh` is handed to `het` (defined elsewhere).
# The keyword values name quantities used in this file:
# - `_hh` returns the marginal value `Va` and the asset policy `a`.
# - `hh_init` returns the initial `(Va,)`.
# NOTE(review): "Pi" is presumably the name of the exogenous transition matrix
# supplied by the caller — confirm against `het`.
hh_labor = het(_hh; exogenous = "Pi", policy = "a", backward = "Va", backward_init = hh_init) 

""" Supporting functions for HA block """

"""
    cn(uc, w, eis, ν, φ)

Return optimal consumption and labor `(c, n)` as functions of marginal utility `uc`,
computed elementwise:

    c = uc^(-eis)
    n = (w * uc / φ)^ν

These invert the first-order conditions `u'(c) = c^(-1/eis)` and
`φ * n^(1/ν) = w * u'(c)`.
"""
function cn(uc, w, eis, ν, φ)
    consumption = uc .^ (-eis)
    labor = (w .* uc ./ φ) .^ ν
    return consumption, labor
end

"""
    solve_cn(w, T, eis, ν, φ, uc_seed)

Return `(c, n)` for points where the budget must hold as `c = w*n + T`.

The solution proceeds in two steps:
1. `solve_uc` finds the marginal utility `uc`, starting from `uc_seed`.
2. `cn` maps `uc` to consumption and labor.
"""
function solve_cn(w, T, eis, ν, φ, uc_seed)
    return cn(solve_uc(w, T, eis, ν, φ, uc_seed), w, eis, ν, φ)
end

"""
    solve_uc(w, T, eis, ν, φ, uc_seed; tol = 1e-11, maxiter = 30)

Solve for the marginal utility of consumption `uc` by Newton's method in log-`uc` space.

For each element, find `uc` such that `c, n = cn(uc, w, eis, ν, φ)` satisfies
the budget `c = w*n + T`. The map `cn` encodes the first-order conditions of

    max_{c, n} c^(1-1/eis)/(1-1/eis) - φ*n^(1+1/ν)/(1+1/ν)  s.t.  c = w*n + T

# Arguments
- `w`, `T`: wage and transfer for each point (broadcast against `uc_seed`).
- `eis`, `ν`, `φ`: preference parameters (see `cn`).
- `uc_seed`: initial guess for `uc`. It must be positive, because its log is taken.
  It may be an array or a scalar.

# Keywords
- `tol`: convergence threshold on the largest absolute net expenditure.
- `maxiter`: maximum number of Newton iterations (residual evaluations).

# Returns
`uc` with the same shape as `uc_seed`. An empty `uc_seed` yields an empty result.

# Throws
`ErrorException` if the residual is not below `tol` within `maxiter` iterations.
"""
function solve_uc(w, T, eis, ν, φ, uc_seed; tol = 1e-11, maxiter = 30)
    log_uc = log.(uc_seed)
    for _ in 1:maxiter
        ne, ne_p = netexp(log_uc, w, T, eis, ν, φ)
        # `all` (unlike `maximum`) is defined on empty input: nothing left to solve.
        # A NaN residual fails the test, as before.
        if all(x -> abs(x) < tol, ne)
            return exp.(log_uc)
        end
        # Newton step. Rebinding (rather than `.-=`) also works for scalar inputs.
        log_uc = log_uc .- ne ./ ne_p
    end
    error("Cannot solve constrained household's problem: No convergence after $maxiter iterations.")
end

"""
    netexp(log_uc, w, T, eis, ν, φ)

Return the net expenditure implied by log marginal utility `log_uc`, together with
its derivative with respect to `log_uc`. Both are computed elementwise.

Net expenditure is `c - w*n - T`, where `(c, n) = cn(exp(log_uc), w, eis, ν, φ)`.
"""
function netexp(log_uc, w, T, eis, ν, φ)
    c, n = cn(exp.(log_uc), w, eis, ν, φ)
    # Elasticities of c and n w.r.t. u'(c) are -eis and ν, so their derivatives
    # w.r.t. log u'(c) are -eis*c and ν*n.
    dc_dlog = -eis * c
    dn_dlog = ν * n
    return c .- w .* n .- T, dc_dlog .- w .* dn_dlog
end
